{ "cells": [
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "## Evaluating Detectors\n",
   "\n",
   "In `scikit-clean`, a `Detector` only identifies (detects) mislabelled samples. It is not a complete classifier (rather, a part of one), so the procedure for evaluating it is different.\n",
   "\n",
   "We can view a noise detector as a binary classifier: its job is to provide a probability denoting whether a sample is \"mislabelled\" or \"clean\". We can therefore use binary classification metrics that work on continuous output: Brier score, log loss, area under the ROC curve, etc.\n",
   "\n",
   "Below, every detector scores each sample of a synthetic dataset with 30% label noise, and those scores are compared against the ground-truth mask of correctly labelled samples."
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 1,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Suppress warnings by replacing `warnings.warn` with a no-op. It runs before the library\n",
   "# imports below (presumably so modules that bind `warn` at import time are covered too).\n",
   "# Remove this before modifying this notebook, otherwise genuine warnings stay hidden.\n",
   "def warn(*args, **kwargs):\n",
   "    pass\n",
   "import warnings\n",
   "warnings.warn = warn\n",
   "\n",
   "import numpy as np\n",
   "import pandas as pd\n",
   "from sklearn.datasets import make_classification\n",
   "from sklearn.metrics import brier_score_loss, log_loss, roc_auc_score\n",
   "\n",
   "from skclean.tests.common_stuff import NOISE_DETECTORS  # All noise detectors in skclean\n",
   "from skclean.utils import load_data\n",
   "from skclean.detectors import KDN, RkDN\n",
   "from skclean.detectors.base import BaseDetector\n",
   "from skclean.simulate_noise import flip_labels_uniform\n",
   "\n",
   "SEED = 42  # fixed seed for reproducible runs"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 2,
  "metadata": {},
  "outputs": [],
  "source": [
   "class DummyDetector(BaseDetector):\n",
   "    \"\"\"Chance-level baseline: ignores the data and returns uniform random scores in [0, 1).\"\"\"\n",
   "\n",
   "    def detect(self, X, y):\n",
   "        return np.random.uniform(size=y.shape)\n",
   "\n",
   "\n",
   "class WkDN(BaseDetector):\n",
   "    \"\"\"Equal-weight average of the KDN and RkDN detector scores.\"\"\"\n",
   "\n",
   "    def detect(self, X, y):\n",
   "        return .5 * KDN().detect(X, y) + .5 * RkDN().detect(X, y)\n",
   "\n",
   "\n",
   "ALL_DETECTOTS = [DummyDetector(), WkDN()] + NOISE_DETECTORS"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 3,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Seed NumPy's global RNG: DummyDetector draws from it, and flip_labels_uniform and any\n",
   "# detectors without an explicit random_state presumably do too (verify).\n",
   "np.random.seed(SEED)\n",
   "X, y = make_classification(1800, 10, random_state=SEED)\n",
   "# X, y = load_data('breast_cancer')  # alternative: a real-world dataset\n",
   "\n",
   "yn = flip_labels_uniform(y, .3)  # 30% label noise\n",
   "clean_idx = (y == yn)  # Boolean mask: True for correctly labelled samples"
  ]
 },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
logbrierroc
DummyDetector0.9990.3330.501
WkDN0.6640.1830.811
ForestKDN1.0990.1310.858
InstanceHardness0.4480.1410.902
KDN0.8300.1730.818
RkDN3.3710.2270.749
MCS0.2940.0710.955
PartitioningDetector0.9420.0720.950
RandomForestDetector0.4640.1450.908
\n", "
" ], "text/plain": [ " log brier roc\n", "DummyDetector 0.999 0.333 0.501\n", "WkDN 0.664 0.183 0.811\n", "ForestKDN 1.099 0.131 0.858\n", "InstanceHardness 0.448 0.141 0.902\n", "KDN 0.830 0.173 0.818\n", "RkDN 3.371 0.227 0.749\n", "MCS 0.294 0.071 0.955\n", "PartitioningDetector 0.942 0.072 0.950\n", "RandomForestDetector 0.464 0.145 0.908" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [
   "# Metric name -> sklearn function; each is called as metric(y_true, y_score).\n",
   "METRICS = {'log': log_loss, 'brier': brier_score_loss, 'roc': roc_auc_score}\n",
   "\n",
   "# Scores are compared against `clean_idx` (True = correctly labelled), so a good\n",
   "# detector assigns high scores to clean samples and low scores to mislabelled ones.\n",
   "scores = {}\n",
   "for detector in ALL_DETECTOTS:\n",
   "    conf_score = detector.detect(X, yn)\n",
   "    scores[detector.__class__.__name__] = {\n",
   "        name: np.round(metric(clean_idx, conf_score), 3)\n",
   "        for name, metric in METRICS.items()\n",
   "    }\n",
   "\n",
   "# Build the table once from the collected records (one row per detector).\n",
   "results = pd.DataFrame.from_dict(scores, orient='index')\n",
   "results"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "Note that for `roc_auc_score` higher is better, whereas for log loss and Brier score lower is better. `DummyDetector` returns random scores, so its row is the chance-level baseline."
  ]
 }
 ],
 "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }